import json
import logging
import os
from typing import List
import jsonlines
from tqdm import tqdm
from transformers import PreTrainedTokenizer


logger = logging.getLogger(__name__)


class InputExample:
    """One training/test instance: a token sequence plus a candidate trigger span.

    Attributes:
        example_id: identifier string for this instance.
        tokens: the tokens of the sentence that contains the trigger.
        triggerL: index of the first trigger token in ``tokens``.
        triggerR: end index of the trigger (used as a slice end, i.e. exclusive).
        label: label name, or None when no label was given.
    """

    def __init__(self, example_id, tokens, triggerL, triggerR, label=None):
        self.example_id, self.tokens = example_id, tokens
        self.triggerL, self.triggerR = triggerL, triggerR
        self.label = label


class InputFeatures:
    """Model-ready, fixed-length encoding of one `InputExample`.

    Attributes:
        example_id: id of the source example.
        input_ids: token ids.
        input_mask: attention mask.
        segment_ids: token type ids.
        maskL: float mask selecting the part of the sequence left of the trigger marker.
        maskR: float mask selecting the part from the trigger marker onward.
        label: integer label id.
    """

    def __init__(self, example_id, input_ids, input_mask, segment_ids, maskL, maskR, label):
        self.example_id = example_id
        self.input_ids, self.input_mask, self.segment_ids = input_ids, input_mask, segment_ids
        self.maskL, self.maskR = maskL, maskR
        self.label = label


class DataProcessor:
    """Abstract base for dataset readers.

    Subclasses override the split loaders they support and `get_labels`;
    every method here simply raises NotImplementedError.
    """

    def get_train_examples(self, data_dir):
        """Return the `InputExample`s of the training split found in *data_dir*."""
        raise NotImplementedError

    def get_dev_examples(self, data_dir):
        """Return the `InputExample`s of the dev split found in *data_dir*."""
        raise NotImplementedError

    def get_test_examples(self, data_dir):
        """Return the `InputExample`s of the test split found in *data_dir*."""
        raise NotImplementedError

    def get_labels(self):
        """Return the list of label names for this dataset."""
        raise NotImplementedError


class LEVENProcessor(DataProcessor):
    """Processor for the LEVEN data set.

    Each split is a JSON-lines file inside ``data_dir`` (``train.jsonl``,
    ``valid.jsonl``, ``test.jsonl``). For every document, each mention of
    each entry in ``events`` becomes an example labelled with the event's
    ``type``, and each entry in ``negative_triggers`` becomes an example
    labelled ``'None'``.
    """

    # JSON file whose keys, in file order, are the label names. The path is
    # relative to the current working directory; override it on a subclass or
    # an instance to load labels from elsewhere.
    label_file = './utils/pure_event2id.json'

    def get_train_examples(self, data_dir):
        """See base class."""
        logger.info("LOOKING AT %s train", data_dir)
        return self._create_examples(os.path.join(data_dir, 'train.jsonl'), "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        logger.info("LOOKING AT %s dev", data_dir)
        return self._create_examples(os.path.join(data_dir, 'valid.jsonl'), "dev")

    def get_test_examples(self, data_dir):
        """See base class.

        The file is parsed like train/dev, so it must carry the same
        ``events`` / ``negative_triggers`` fields.
        """
        logger.info("LOOKING AT %s test", data_dir)
        return self._create_examples(os.path.join(data_dir, 'test.jsonl'), "test")

    def get_labels(self):
        """See base class. Returns the keys of ``label_file`` in file order."""
        # Context manager so the file handle is closed instead of leaked.
        with open(self.label_file, encoding='utf-8') as f:
            return list(json.load(f).keys())

    @staticmethod
    def _create_examples(fin, set_type):
        """Create examples for one split.

        Args:
            fin: path of the JSON-lines file to read.
            set_type: split name; used as the prefix of every example id.

        Returns:
            A list of `InputExample`, document by document: first the gold
            mentions (labelled with the event type), then the negative
            triggers (labelled ``'None'``).
        """

        def make_example(sentences, item, label):
            # ``item`` is a mention or negative trigger with ``id``,
            # ``sent_id`` (index into the document's ``content``) and
            # ``offset`` (trigger start / end token indices).
            return InputExample(
                example_id="%s-%s" % (set_type, item['id']),
                tokens=sentences[item['sent_id']]['tokens'],
                triggerL=item['offset'][0],
                triggerR=item['offset'][1],
                label=label,
            )

        examples = []
        # Close the reader once all lines are consumed.
        with jsonlines.open(fin) as reader:
            for data in reader:
                sentences = data['content']
                for event in data['events']:
                    for mention in event['mention']:
                        examples.append(make_example(sentences, mention, event['type']))
                for nt in data['negative_triggers']:
                    examples.append(make_example(sentences, nt, 'None'))

        return examples


class LEVENInferProcessor(DataProcessor):
    """Inference-time processor for LEVEN.

    Only the test split is supported. Each entry of a document's
    ``candidates`` becomes an example with the placeholder label ``'None'``.
    """

    # JSON file whose keys, in file order, are the label names. The path is
    # relative to the current working directory; override it on a subclass or
    # an instance to load labels from elsewhere.
    label_file = './utils/pure_event2id.json'

    def get_test_examples(self, data_dir):
        """See base class."""
        logger.info("LOOKING AT %s test", data_dir)
        return self._create_examples(os.path.join(data_dir, 'test.jsonl'), "test")

    def get_labels(self):
        """See base class. Returns the keys of ``label_file`` in file order."""
        # Context manager so the file handle is closed instead of leaked.
        with open(self.label_file, encoding='utf-8') as f:
            return list(json.load(f).keys())

    @staticmethod
    def _create_examples(fin, set_type):
        """Create examples for the test set.

        Args:
            fin: path of the JSON-lines file to read.
            set_type: split name; used as the prefix of every example id.

        Returns:
            One `InputExample` per candidate trigger, in file order, each
            labelled ``'None'`` (a placeholder; the gold label is unknown).
        """
        examples = []
        # Close the reader once all lines are consumed.
        with jsonlines.open(fin) as reader:
            for data in reader:
                sentences = data['content']
                for mention in data['candidates']:
                    examples.append(
                        InputExample(
                            example_id="%s-%s" % (set_type, mention['id']),
                            tokens=sentences[mention['sent_id']]['tokens'],
                            triggerL=mention['offset'][0],
                            triggerR=mention['offset'][1],
                            label='None',
                        )
                    )
        return examples


def convert_examples_to_features(
    examples: List[InputExample],
    label_list: List[str],
    max_length: int,
    tokenizer: PreTrainedTokenizer,
    pad_token_segment_id=0,
    pad_on_left=False,
    pad_token=0,
    mask_padding_with_zero=True,
    num_debug_examples=0,
) -> List[InputFeatures]:
    """
    Convert `InputExample`s into padded, fixed-length `InputFeatures`.

    Each example is encoded as the token sequence
    ``left-context [unused0] trigger [unused1] right-context`` with the
    tokenizer's special tokens added. Two float masks split the encoded
    sequence at the ``[unused0]`` marker:

    * ``maskL`` is 1.0 on the leading special token and the left context.
    * ``maskR`` is 1.0 from ``[unused0]`` through the trailing special token.
    * Both masks are 0.0 on padding.

    Args:
        examples: examples to convert.
        label_list: label names; a label's id is its index in this list.
        max_length: length every feature is truncated / padded to.
        tokenizer: used for ``tokenize`` and ``encode_plus``. The mask
            arithmetic assumes it adds exactly one special token before and
            one after the sequence (e.g. BERT's [CLS]/[SEP]); otherwise the
            length assertions below fail.
        pad_token_segment_id: token type id written at padding positions.
        pad_on_left: pad at the start of the sequence instead of the end.
        pad_token: input id written at padding positions.
        mask_padding_with_zero: if True, attention_mask is 1 for real tokens
            and 0 for padding; if False those values are swapped.
        num_debug_examples: log the full features of this many leading
            examples (default 0 logs none).

    Returns:
        One `InputFeatures` per example, in input order.

    Raises:
        KeyError: if an example's label is not in ``label_list``.
    """
    label_map = {label: i for i, label in enumerate(label_list)}

    features = []
    for ex_index, example in enumerate(tqdm(examples, desc='convert examples to features')):
        if ex_index % 10000 == 0:
            logger.info("Writing example %d of %d", ex_index, len(examples))

        # leven is in Chinese, therefore, use "".join() instead of " ".join()
        textL = tokenizer.tokenize("".join(example.tokens[:example.triggerL]))

        textR = tokenizer.tokenize("".join(example.tokens[example.triggerL:example.triggerR]))
        textR += ['[unused1]']
        textR += tokenizer.tokenize("".join(example.tokens[example.triggerR:]))

        # Left part: leading special token + textL.
        # Right part: [unused0] + textR + trailing special token.
        left_len = len(textL) + 1
        right_len = len(textR) + 2
        maskL = ([1.0] * left_len + [0.0] * right_len)[:max_length]
        maskR = ([0.0] * left_len + [1.0] * right_len)[:max_length]

        inputs = tokenizer.encode_plus(textL + ['[unused0]'] + textR,
                                       add_special_tokens=True,
                                       max_length=max_length,
                                       return_token_type_ids=True,
                                       return_overflowing_tokens=True)

        if "num_truncated_tokens" in inputs and inputs["num_truncated_tokens"] > 0:
            logger.info("Attention! you are cropping tokens.")

        # [unused0] sits at index left_len. Assuming truncation keeps the
        # trailing special token in the last slot, the marker survives only
        # if left_len <= max_length - 2. Otherwise the trigger is gone and
        # the feature carries no trigger information.
        if left_len + 2 > max_length:
            logger.warning(
                "Trigger of example %s lies beyond max_length=%d and is cropped.",
                example.example_id, max_length,
            )

        input_ids, token_type_ids = inputs["input_ids"], inputs["token_type_ids"]
        assert len(input_ids) == len(maskL)
        assert len(input_ids) == len(maskR)
        # The mask has 1 for real tokens and 0 for padding tokens. Only real
        # tokens are attended to.
        attention_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)
        # Zero-pad up to the sequence length.
        padding_length = max_length - len(input_ids)
        if pad_on_left:
            input_ids = ([pad_token] * padding_length) + input_ids
            attention_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + attention_mask
            token_type_ids = ([pad_token_segment_id] * padding_length) + token_type_ids
            maskL = ([0.0] * padding_length) + maskL
            maskR = ([0.0] * padding_length) + maskR
        else:
            input_ids = input_ids + ([pad_token] * padding_length)
            attention_mask = attention_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
            token_type_ids = token_type_ids + ([pad_token_segment_id] * padding_length)
            maskL = maskL + ([0.0] * padding_length)
            maskR = maskR + ([0.0] * padding_length)

        assert len(input_ids) == max_length
        assert len(attention_mask) == max_length
        assert len(token_type_ids) == max_length

        try:
            label = label_map[example.label]
        except KeyError:
            # Same exception type as before, but say which label / example failed.
            raise KeyError(
                "Unknown label {!r} for example {} (not in label_list)".format(
                    example.label, example.example_id)
            ) from None

        if ex_index < num_debug_examples:
            logger.info("*** Example ***")
            logger.info("example_id: {}".format(example.example_id))
            logger.info("input_ids: {}".format(" ".join(map(str, input_ids))))
            logger.info("attention_mask: {}".format(" ".join(map(str, attention_mask))))
            logger.info("token_type_ids: {}".format(" ".join(map(str, token_type_ids))))
            logger.info("maskL: {}".format(" ".join(map(str, maskL))))
            logger.info("maskR: {}".format(" ".join(map(str, maskR))))
            logger.info("label: {}".format(label))

        features.append(InputFeatures(example_id=example.example_id,
                                      input_ids=input_ids,
                                      input_mask=attention_mask,
                                      segment_ids=token_type_ids,
                                      maskL=maskL,
                                      maskR=maskR,
                                      label=label))

    return features


# Registry mapping a task name to its processor class.
processors = dict(
    leven=LEVENProcessor,
    leven_infer=LEVENInferProcessor,
)


# Number of output labels per task name. It must be a dict (task -> count) so
# it can be indexed by task. The value presumably equals the number of keys in
# ./utils/pure_event2id.json; keep the two in sync.
MULTIPLE_CHOICE_TASKS_NUM_LABELS = {"leven": 109}
